# -*- coding: utf-8 -*-
# @Time : 2021/7/26 下午4:44
# @Author : fugang_le

from configs.classification_config import FACILITY_RESOURCE_POOL, MATERIAL_RESOURCE_POOL
from src.classification.material.textcnn.predict import CnnModel as material_model
from src.classification.facility.textcnn.predict import CnnModel as facility_model


class Classification:
    """Dispatch classification requests to the facility or material model.

    Holds one predictor per resource pool and selects between them based
    on the ``price_type`` passed to :meth:`predict`.
    """

    def __init__(self):
        # One predictor instance per resource pool, created up front.
        self.material_model = material_model()
        self.facility_model = facility_model()

    @staticmethod
    def _build_inputs(data):
        """Return one text per item: the material name followed by its spec."""
        # NOTE(review): the key is spelled 'material_sepc' (sic). It is kept
        # as-is because callers may send exactly this key -- confirm against
        # the request schema before renaming it.
        # `or ''` also covers a key that is present but None/empty, which
        # would otherwise break the string concatenation.
        return [
            item['material_name'] + (item.get('material_sepc') or '')
            for item in data
        ]

    def predict(self, price_type, data, topk=1):
        """Classify each item in ``data`` with the model for ``price_type``.

        Args:
            price_type: Resource pool selector; compared against
                ``FACILITY_RESOURCE_POOL`` and ``MATERIAL_RESOURCE_POOL``.
            data: Iterable of dict-like items. Each must contain
                ``'material_name'``; ``'material_sepc'`` is optional.
            topk: Forwarded unchanged to the selected model's ``predict``.

        Returns:
            Whatever the selected model's ``predict`` returns, or an empty
            list when ``data`` is empty or ``price_type`` matches neither
            resource pool.

        Raises:
            KeyError: If an item lacks ``'material_name'`` (only for a
                recognised ``price_type``).
        """
        # Nothing to classify: skip the model call entirely.
        if not data:
            return []

        # Pick the model first so an unknown price type does no input work.
        if price_type == FACILITY_RESOURCE_POOL:
            model = self.facility_model
        elif price_type == MATERIAL_RESOURCE_POOL:
            model = self.material_model
        else:
            return []

        return model.predict(self._build_inputs(data), topk)

